[Large Tensor] Fixed Spatial Transformer op #17617
Conversation
@mxnet-label-bot add [pr-awaiting-review]
@connorgoggins can you tell me what the possible types for DType are here:
@access2rohit DType here represents the type of the interior elements of the input tensors, and since SpatialTransformer only supports floating-point types, DType could be
c5ff973 to 871849e
index_t top_right_v = 0;
index_t bottom_left_v = 0;
index_t bottom_right_v = 0;
DType top_left_v = 0;
Why is this changed back to DType?
It's the index position, right?
@ChaiBapchya changing DType to index_t for these specific variables makes the op generate incorrect output on standard inputs (e.g. the inputs in the CI run) - the values generated in the output NDArray are all integers instead of floats. This is due to the fact that these variables do not represent the index positions (as I also originally believed), but instead represent the underlying values at the vertices.
Cool. good catch.
_v indicates the value at that particular index. Thanks!
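To illustrate the distinction made in this thread, here is a minimal sketch of bilinear sampling in which grid positions use `index_t` and the values read at the four vertices use `DType`. It is not the code from `spatial_transformer.cc`: the `index_t` alias is assumed to be a 64-bit signed integer, and `BilinearSample` is a hypothetical helper written only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumption: a 64-bit signed stand-in for MXNet's index_t.
using index_t = std::int64_t;

// Hypothetical helper: bilinearly samples a row-major height x width image
// at the fractional position (y, x).
template <typename DType>
DType BilinearSample(const std::vector<DType>& data, index_t height,
                     index_t width, DType y, DType x) {
  assert(y >= 0 && x >= 0);
  assert(y <= static_cast<DType>(height - 1));
  assert(x <= static_cast<DType>(width - 1));

  // Positions in the grid are indices, so they use index_t.
  const index_t top = static_cast<index_t>(y);   // truncation equals floor for y >= 0
  const index_t left = static_cast<index_t>(x);
  const index_t bottom = (top + 1 < height) ? top + 1 : top;
  const index_t right = (left + 1 < width) ? left + 1 : left;

  auto at = [&](index_t row, index_t col) -> DType {
    return data[static_cast<std::size_t>(row * width + col)];
  };

  // Values read at the four vertices are tensor elements, so they use DType.
  const DType top_left_v = at(top, left);
  const DType top_right_v = at(top, right);
  const DType bottom_left_v = at(bottom, left);
  const DType bottom_right_v = at(bottom, right);

  // Interpolation weights are fractional, so they are DType as well.
  const DType wy = y - static_cast<DType>(top);
  const DType wx = x - static_cast<DType>(left);
  return (1 - wy) * (1 - wx) * top_left_v + (1 - wy) * wx * top_right_v +
         wy * (1 - wx) * bottom_left_v + wy * wx * bottom_right_v;
}
```

For a 2 x 2 image holding 0.5, 1.5, 2.5, 3.5, sampling at (0.5, 0.5) returns 2.0. If the `_v` variables were declared as `index_t`, each vertex value would be truncated to an integer before interpolation, which is consistent with the integer-only output described above.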
LGTM!
d3696f2 to 50d1530
50d1530 to 5ac85a6
* Added CPU fix
* Added fix for backward on CPU
* Fixed lint error
* index_t for lower bound instead of hardcoded long int
* Fixed remaining lint errors
* Removed trailing whitespace
* Reverting to DType for vertices
* Added nightly test for SpatialTransformer
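The `index_t` commits matter because a flattened element offset into a tensor with 2^32 or more elements cannot be represented in a 32-bit `int`. The sketch below shows the arithmetic under stated assumptions: `index_t` is aliased to a 64-bit signed integer for illustration, and `FlatOffset` and `FitsInInt32` are hypothetical helpers, not functions from `spatial_transformer.cc`.

```cpp
#include <cstdint>
#include <limits>

// Assumption: a 64-bit signed stand-in for MXNet's index_t.
using index_t = std::int64_t;

// Hypothetical helper: flattened offset of element (n, c, h, w) in an NCHW
// tensor with C channels, height H and width W, computed in 64-bit arithmetic.
inline index_t FlatOffset(index_t n, index_t c, index_t h, index_t w,
                          index_t C, index_t H, index_t W) {
  return ((n * C + c) * H + h) * W + w;
}

// Hypothetical helper: true if the offset is representable as a 32-bit signed int.
inline bool FitsInInt32(index_t offset) {
  return offset >= std::numeric_limits<std::int32_t>::min() &&
         offset <= std::numeric_limits<std::int32_t>::max();
}
```

For a 1 x 1 x 65536 x 65536 tensor, the last element sits at offset 4294967295, which exceeds the 32-bit signed maximum of 2147483647, so a loop counter or offset held in `int` cannot reach it.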
Description
The Spatial Transformer op was previously breaking on large tensor (dimension >= 2^32) data. With the following input:
the following error was thrown:
To root cause this issue, I ran the previous command in a Python script with GDB, and found that the underlying problem was in the iteration portion of the forward and backward methods of spatial_transformer.cc. Several of the variables used in the iteration used the int dtype when they should have been using index_t to properly handle long int indices. I switched these variables to index_t in the forward and backward methods, and after rebuilding, the previous input command displayed the correct output:
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
Tested on r5dn.24xl-ubuntu 16.04 and p2.16xl-ubuntu 16.04 with
Results
The key difference between CPU and GPU tests was the instance type (r5dn.24xl for CPU, p2.16xl for GPU). All relevant build flags remain the same, and both were tested using CPU context.
Single operator test - SpatialTransformer op (GPU)
Single operator test - SpatialTransformer op (CPU)
Full OpPerf test (GPU)
Full OpPerf test (CPU)
@apeforest @access2rohit @ChaiBapchya